Fix private memory size too large in sample_recovered_tokens_kernel#115
Conversation
Signed-off-by: Xin Li <xin.li@metax-tech.com>
Code Review
This pull request aims to fix a `memory size or pointer value too large` error in a Triton kernel. The main change introduces a blocked implementation of `sample_recovered_tokens_kernel` to reduce per-thread memory usage. My review focuses on making the new kernel robust against integer overflow during pointer arithmetic, which is a likely cause of the original error. I have identified two places where 32-bit integer overflow could still occur and have suggested casting to 64-bit integers to prevent it, improving the correctness and reliability of the fix.
```python
max_prob = -float('inf')
best_token_id = 0

for block_start in range(0, PADDED_VOCAB_SIZE, BLOCK_SIZE):
    block_end = min(block_start + BLOCK_SIZE, vocab_size)

    vocab_offset = tl.arange(0, BLOCK_SIZE)
    mask = vocab_offset < block_end - block_start

    if NO_DRAFT_PROBS:
        draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos)
        prob = tl.load(
            target_probs_ptr + (start_idx + pos) * vocab_size +
            block_start + vocab_offset,
            mask=(mask & (vocab_offset + block_start != draft_token_id)),
            other=0)
    else:
        draft_prob = tl.load(draft_probs_ptr +
                             (start_idx + pos) * vocab_size + block_start +
                             vocab_offset,
                             mask=mask,
                             other=0)
        target_prob = tl.load(target_probs_ptr +
                              (start_idx + pos) * vocab_size +
                              block_start + vocab_offset,
                              mask=mask,
                              other=0)
        prob = tl.maximum(target_prob - draft_prob, 0)
```
The pointer offset calculation `(start_idx + pos) * vocab_size` is repeated and may suffer from 32-bit integer overflow if `start_idx` is loaded as a 32-bit integer and multiplied by a large `vocab_size`. This can lead to incorrect memory accesses and is a likely cause of the `pointer value too large to fit in 32 bit` error.
To ensure correctness and improve readability, compute the base offset once outside the loop, explicitly casting to `tl.int64` to rule out any potential overflow:
```python
token_idx = start_idx + pos
# Cast to int64 to prevent overflow when calculating pointer offsets.
base_offset = token_idx.to(tl.int64) * vocab_size
max_prob = -float('inf')
best_token_id = 0
for block_start in range(0, PADDED_VOCAB_SIZE, BLOCK_SIZE):
    block_end = min(block_start + BLOCK_SIZE, vocab_size)
    vocab_offset = tl.arange(0, BLOCK_SIZE)
    mask = vocab_offset < block_end - block_start
    if NO_DRAFT_PROBS:
        draft_token_id = tl.load(draft_token_ids_ptr + token_idx)
        prob = tl.load(
            target_probs_ptr + base_offset + block_start + vocab_offset,
            mask=(mask & (vocab_offset + block_start != draft_token_id)),
            other=0)
    else:
        draft_prob = tl.load(draft_probs_ptr + base_offset + block_start +
                             vocab_offset,
                             mask=mask,
                             other=0)
        target_prob = tl.load(target_probs_ptr + base_offset +
                              block_start + vocab_offset,
                              mask=mask,
                              other=0)
        prob = tl.maximum(target_prob - draft_prob, 0)
```

```python
q = tl.load(q_ptr + req_idx * vocab_size + block_start + vocab_offset,
            mask=mask,
            other=float("-inf"))
```
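To see concretely why the cast matters, here is a small standalone Python sketch. The vocabulary size and token index below are illustrative values, not taken from this PR; they merely show how the flat offset `(start_idx + pos) * vocab_size` can exceed the signed 32-bit range and what a wrapped 32-bit result would look like:

```python
# Illustrative values only (not from this PR): a ~152k-entry vocabulary
# and a token index from a long batch.
vocab_size = 152064
token_idx = 20000  # plays the role of start_idx + pos

# Flat element offset into a [num_tokens, vocab_size] probability tensor.
offset = token_idx * vocab_size

# The true offset no longer fits in a signed 32-bit integer...
assert offset > 2**31 - 1

# ...so a 32-bit multiply wraps around (two's complement) to a negative
# value, which then yields an out-of-range pointer.
wrapped = (offset + 2**31) % 2**32 - 2**31
print(offset, wrapped)  # 3041280000 -1253687296
```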
For consistency and to prevent potential 32-bit integer overflow, the offset calculation for `q_ptr` should also use 64-bit integers. While `req_idx` is likely small, multiplying it by a large `vocab_size` could still overflow on some platforms or with very large batches. Using `tl.int64` makes the calculation safe.
```python
# Before
q = tl.load(q_ptr + req_idx * vocab_size + block_start + vocab_offset,
            mask=mask,
            other=float("-inf"))

# After
q = tl.load(q_ptr + req_idx.to(tl.int64) * vocab_size + block_start + vocab_offset,
            mask=mask,
            other=float("-inf"))
```
Code Review
This pull request aims to fix a `Triton Error` related to memory size and pointer values in `sample_recovered_tokens_kernel`. Iterating over the vocabulary in blocks is a good solution to the private memory size issue. However, the fix appears incomplete: it does not address potential 32-bit integer overflow in the pointer offset calculations. I have added a suggestion to use 64-bit integers explicitly for these calculations to make the fix robust. The other changes in this pull request are correct.
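The memory-saving idea behind the blocked rewrite can be illustrated outside Triton with a small NumPy sketch (the function name and sizes below are hypothetical, chosen for illustration): instead of materializing the whole vocabulary-sized row at once, the loop keeps only one block-sized slice live and carries a running maximum across blocks.

```python
import numpy as np

def blockwise_argmax(row, block_size):
    """Running argmax over `row`, touching only `block_size` values at a time.

    Mirrors the kernel's loop: each iteration examines one block, compares
    it against the best value seen so far, then discards the block.
    """
    best_val = -np.inf
    best_idx = 0
    for block_start in range(0, len(row), block_size):
        block = row[block_start:block_start + block_size]
        local = int(np.argmax(block))
        if block[local] > best_val:
            best_val = float(block[local])
            best_idx = block_start + local
    return best_idx

# Sanity check against the one-shot argmax over a vocabulary-sized row.
row = np.random.default_rng(0).standard_normal(150_000)
assert blockwise_argmax(row, 1024) == int(np.argmax(row))
```

Peak working-set size drops from the full row to one block, which is exactly why the blocked kernel avoids the "private memory size too large" failure.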
```python
max_prob = -float('inf')
best_token_id = 0

for block_start in range(0, PADDED_VOCAB_SIZE, BLOCK_SIZE):
    block_end = min(block_start + BLOCK_SIZE, vocab_size)

    vocab_offset = tl.arange(0, BLOCK_SIZE)
    mask = vocab_offset < block_end - block_start

    if NO_DRAFT_PROBS:
        draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos)
        prob = tl.load(
            target_probs_ptr + (start_idx + pos) * vocab_size +
            block_start + vocab_offset,
            mask=(mask & (vocab_offset + block_start != draft_token_id)),
            other=0)
    else:
        draft_prob = tl.load(draft_probs_ptr +
                             (start_idx + pos) * vocab_size + block_start +
                             vocab_offset,
                             mask=mask,
                             other=0)
        target_prob = tl.load(target_probs_ptr +
                              (start_idx + pos) * vocab_size +
                              block_start + vocab_offset,
                              mask=mask,
                              other=0)
        prob = tl.maximum(target_prob - draft_prob, 0)

    # NOTE(woosuk): We don't need `prob = prob / tl.sum(prob)` here because
    # `tl.argmax` will select the maximum value.

    q = tl.load(q_ptr + req_idx * vocab_size + block_start + vocab_offset,
                mask=mask,
                other=float("-inf"))
```
The PR title and description mention fixing `RuntimeError: Triton Error [MACA]: memory size or pointer value too large to fit in 32 bit`. While iterating over the vocabulary in blocks fixes the "private memory size too large" issue, the "pointer value too large" issue can still occur.
Pointer offset calculations such as `(start_idx + pos) * vocab_size` and `req_idx * vocab_size` can overflow a 32-bit integer when the total number of tokens or the vocabulary size is large, triggering the error.
This suggestion performs the offset calculations in 64-bit arithmetic by explicitly casting to `tl.int64` before the multiplication. This prevents the overflow, fully resolves the reported error, and makes the kernel more robust:
```python
token_idx = start_idx + pos
# Ensure 64-bit arithmetic for large offsets to prevent overflow.
probs_base_offset = token_idx.to(tl.int64) * vocab_size
q_base_offset = req_idx.to(tl.int64) * vocab_size
max_prob = -float('inf')
best_token_id = 0
for block_start in range(0, PADDED_VOCAB_SIZE, BLOCK_SIZE):
    block_end = min(block_start + BLOCK_SIZE, vocab_size)
    vocab_offset = tl.arange(0, BLOCK_SIZE)
    mask = vocab_offset < block_end - block_start
    if NO_DRAFT_PROBS:
        draft_token_id = tl.load(draft_token_ids_ptr + token_idx)
        prob = tl.load(
            target_probs_ptr + probs_base_offset + block_start +
            vocab_offset,
            mask=(mask & (vocab_offset + block_start != draft_token_id)),
            other=0)
    else:
        draft_prob = tl.load(draft_probs_ptr + probs_base_offset +
                             block_start + vocab_offset,
                             mask=mask,
                             other=0)
        target_prob = tl.load(target_probs_ptr + probs_base_offset +
                              block_start + vocab_offset,
                              mask=mask,
                              other=0)
        prob = tl.maximum(target_prob - draft_prob, 0)
    # NOTE(woosuk): We don't need `prob = prob / tl.sum(prob)` here because
    # `tl.argmax` will select the maximum value.
    q = tl.load(q_ptr + q_base_offset + block_start + vocab_offset,
                mask=mask,
                other=float("-inf"))
```
(#115)
* [fix] fix sample_recovered_tokens_kernel use too much private memory
* [fix] fix type error in bf16_paged_mqa_logits
* [chore] change file directory

Signed-off-by: Xin Li <xin.li@metax-tech.com>
Co-authored-by: Xin Li <xin.li@metax-tech.com>
* support platform and remove kernel copy
* update pre-commit
* update version and requirements
* update flashinfer
* update build requirements
* update attention backends
* update patch
* update quant_method
* update fuse_moe (todo: fix mypy)
* update `deepseek_v2.py` (todo: fix indexer kernel)
* [feat] support bf16 cp_gather_indexer_k_cache kernel
* [fix] fix type error in bf16_paged_mqa_logits
* [feat] add topk logits ops
* [fix] private memory size too large in `sample_recovered_tokens_kernel` (#115)
* [fix] fix missing topk logits custom ops definition
* [fix] add custom gptq_shuffle ops
* [fix] fix compile error
* platform config update
* update qwen2.5_vl model
* [fix] fix torch not found maca device
* remove hotfixes patch for torch2.8
* remove needless patch (related: vllm-project/vllm/pull/27322)
* [feat] topk_softmax support renormalize and bf16
* [fix] update fused_moe to fit v0.11.1
* [fix] fix fused moe config log missing
* use flash_attn as vit attn backend on qwen_vl
* update quant_conf registry
* fix and apply latest pre-commit of v0.11.1
* [feat] Keep all AITER kernels in _aiter_ops
* fix pre-commit on type casting
* [fix] fix DeepSeek import error
* [feat] update deepseek_v2 to fit v0.11.1

Signed-off-by: Hank <hcc.mayday@gmail.com>
Signed-off-by: Xin Li <lixin1620@gmail.com>
Signed-off-by: leex404 <lixin1620@gmail.com>
Co-authored-by: Xin Li <xin.li@metax-tech.com>
Co-authored-by: leex404 <lixin1620@gmail.com>
Co-authored-by: leex404 <42941760+leex404@users.noreply.github.com>
Purpose
Fix the following error:

```
RuntimeError: Triton Error [MACA]: memory size or pointer value too large to fit in 32 bit
```

Test Plan
Test Result